import os
import argparse
import torch
import time
import random
import wandb
import colossalai
import torch.distributed as dist
import numpy as np
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import GPT2Tokenizer, DataCollatorForLanguageModeling
from datasets import Dataset
from contextlib import nullcontext
import falp_policy as customPolicy
from modeling_falp import GPT2Config, GPT2LMHeadModel

def parse_arguments():
    """Parse the command-line options of the training script.

    Options are read from ``sys.argv``. After parsing, a ``device`` attribute
    holding ``torch.device('cuda')`` is attached to the namespace.

    Returns:
        argparse.Namespace: the parsed options plus ``device``.
    """
    cli = argparse.ArgumentParser()

    # (flag, add_argument keyword arguments), kept in help-output order.
    option_table = (
        ("--mode", dict(default='train', type=str)),
        ("--parallel", dict(default='TP', type=str, help='DP or TP')),
        ("--tp_size", dict(default=2, type=int)),
        ("--dataset", dict(default='openwebtext', type=str, help='wikitext or openwebtext')),
        ("--model", dict(default='GPT2', type=str, help='GPT2, BERT, Llama2')),
        ("--model_name", dict(default='GPT2-L', type=str, help='GPT2-B,M,L,XL or BERT-B,L or Llama2-7B')),
        ("--epoch", dict(default=1, type=int)),
        ("--gradient_accumulation", dict(action='store_true')),
        ("--gradient_accumulation_value", dict(default=4, type=int)),
        ("--gradient_clipping", dict(action='store_true')),
        ("--batch", dict(default=8, type=int)),
        ("--max_seqlength", dict(default=1024, type=int)),
        ("--lr", dict(default=1e-4, type=float)),
        ("--weight_decay", dict(default=0.01, type=float)),
        ("--hidden_dropout", dict(default=0.0, type=float)),
        ("--attn_dropout", dict(default=0.0, type=float)),
        ("--eps", dict(default=1e-6, type=float)),
        ("--seed", dict(default=42, type=int)),
        ("--num_workers", dict(default=8, type=int)),
        ("--wandb", dict(action='store_true')),
        ("--project_name", dict(default='GPT2-L-openwebtext', type=str)),
        ("--start_epoch", dict(default=0, type=int, metavar="N", help="start epoch")),
        ("--output_dir", dict(default='./checkpoints-L-overparm', type=str, help="Directory to save checkpoints")),
        ("--start_batch", dict(default=0, type=int, help="start batch")),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    parsed = cli.parse_args()
    parsed.device = torch.device('cuda')
    return parsed

def set_seed(seed):
    """Seed the global RNGs of PyTorch, NumPy and Python's ``random`` with `seed`."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)

def save_dataset(dataset, path):
    """Write `dataset` to `path` by delegating to its ``save_to_disk`` method."""
    write_to = dataset.save_to_disk
    write_to(path)

def load_dataset_from_disk(path):
    """Return the result of ``Dataset.load_from_disk(path)``."""
    loaded = Dataset.load_from_disk(path)
    return loaded

def filter_rows(example):
    """Return True when the ``'text'`` field of `example` is not the empty string."""
    text = example['text']
    return text != ''

def load_dataset(args):
    """Load the pre-tokenized train and validation datasets from disk.

    The two directories are derived from ``args.dataset`` and
    ``args.max_seqlength`` under ``/mnt/dataset/huggingface``.

    Returns:
        tuple: ``(train_dataset, eval_dataset, number_of_train_rows)``.

    Raises:
        NotImplementedError: if either directory does not exist.
    """
    split_paths = {
        split: f"/mnt/dataset/huggingface/{args.dataset}_{split}_seq:{args.max_seqlength}_dynamic:False"
        for split in ("train", "val")
    }

    # Guard clause: both splits must already be present on disk.
    if not all(os.path.exists(location) for location in split_paths.values()):
        raise NotImplementedError("the dataset path doesn't exist")

    train_tokenized = load_dataset_from_disk(split_paths["train"])
    eval_tokenized = load_dataset_from_disk(split_paths["val"])
    return train_tokenized, eval_tokenized, len(train_tokenized)

def create_dataloaders(args, train_dataset, eval_dataset, data_collator, plugin):
    """Build the train and validation loaders through ``plugin.prepare_dataloader``.

    The batch size handed to the plugin is ``args.batch`` floor-divided by
    ``args.world_size // plugin.tp_size``. Both loaders are requested with the
    same settings (shuffled, seeded with ``args.seed``, ``drop_last=False``).

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """
    replica_count = args.world_size // plugin.tp_size
    shared_options = dict(
        batch_size=args.batch // replica_count,
        shuffle=True,
        seed=args.seed,
        num_workers=args.num_workers,
        drop_last=False,
        collate_fn=data_collator,
    )
    train_loader = plugin.prepare_dataloader(train_dataset, **shared_options)
    val_loader = plugin.prepare_dataloader(eval_dataset, **shared_options)
    return train_loader, val_loader

def create_model(args):
    """Instantiate the model selected by ``args.model`` / ``args.model_name``.

    Only ``args.model == 'GPT2'`` is supported. As a side effect the chosen
    preset is written back onto `args` (``activation``, ``num_layer``,
    ``num_head``, ``hidden_dim``).

    Returns:
        GPT2LMHeadModel: a model built from the selected preset.

    Raises:
        NotImplementedError: if ``args.model`` is not ``'GPT2'``.
        ValueError: if ``args.model_name`` is not a known GPT2 preset.
    """
    if args.model != 'GPT2':
        raise NotImplementedError(f"Model {args.model} not implemented")

    args.activation = 'gelu_new'

    # preset name -> (num_layer, num_head, hidden_dim)
    presets = {
        'GPT2-B': (12, 12, 768),
        'GPT2-M': (24, 16, 1024),
        'GPT2-L': (36, 20, 1280),
        'GPT2-L2': (40, 20, 1400),
        'GPT2-XL': (48, 24, 1584),  # original: 48, 25, 1600
    }
    if args.model_name not in presets:
        raise ValueError(f"Unknown model_name {args.model_name} for GPT2")
    args.num_layer, args.num_head, args.hidden_dim = presets[args.model_name]

    configuration = GPT2Config(
        n_positions=args.max_seqlength,
        n_embd=args.hidden_dim,
        n_layer=args.num_layer,
        n_head=args.num_head,
        activation_function=args.activation,
        resid_pdrop=args.hidden_dropout,
        attn_pdrop=args.attn_dropout,
    )
    return GPT2LMHeadModel(configuration)

def create_optimizer(args, model):
    """Return an AdamW optimizer over ``model.parameters()``.

    Learning rate, weight decay and epsilon come from ``args.lr``,
    ``args.weight_decay`` and ``args.eps``.
    """
    hyperparameters = {
        "lr": args.lr,
        "weight_decay": args.weight_decay,
        "eps": args.eps,
    }
    return torch.optim.AdamW(model.parameters(), **hyperparameters)

def get_latest_checkpoint(output_dir, prefix):
    """Find the entry in `output_dir` with the highest batch index.

    Candidates are entries whose name starts with `prefix` and whose text after
    the last ``"_batch_idx"`` parses as an integer. Entries that match the
    prefix but carry a non-numeric suffix are ignored instead of aborting the
    lookup, and a missing directory is treated as "no checkpoint".

    Args:
        output_dir: directory to scan.
        prefix: required start of the checkpoint name.

    Returns:
        tuple: ``(path, batch_idx)`` of the newest checkpoint, or
        ``(None, None)`` when none is found.
    """
    try:
        entries = os.listdir(output_dir)
    except FileNotFoundError:
        return None, None

    best_name, best_idx = None, None
    for name in entries:
        if not name.startswith(prefix):
            continue
        try:
            idx = int(name.rpartition("_batch_idx")[2])
        except ValueError:
            # e.g. a temp/partial entry such as "<prefix>_tmp"
            continue
        # Strict comparison keeps the first entry seen on ties.
        if best_idx is None or idx > best_idx:
            best_name, best_idx = name, idx

    if best_name is None:
        return None, None
    return os.path.join(output_dir, best_name), best_idx



def main():
    """Entry point: set up distributed training and run the train/validate loop.

    In order: parse the CLI arguments, seed the RNGs, call
    ``colossalai.launch_from_torch()``, load the GPT-2 tokenizer and the
    pre-tokenized datasets, and build the ``HybridParallelPlugin`` and
    ``Booster``. In ``train`` mode it then boosts model and optimizer, resumes
    from the newest ``checkpoint_{model,optimizer}_batch_idx<N>`` entry in
    ``args.output_dir`` if one exists, and for every epoch trains, validates
    and saves sharded model/optimizer checkpoints. Any other mode does nothing
    after the setup.

    Raises:
        NotImplementedError: if ``args.model`` is not ``'GPT2'``.
        ValueError: if ``args.parallel`` is not ``'TP'``.
    """
    args = parse_arguments()
    # NOTE(review): this check (and the ones below) compares the current CUDA
    # device index, not the global rank, and this one runs before the
    # distributed launch — confirm that is the intended "main process" test.
    if torch.cuda.current_device() == 0:
        os.makedirs(args.output_dir, exist_ok=True)
    set_seed(args.seed)
    os.environ['NCCL_IB_DISABLE'] = '1'
    os.environ['NCCL_P2P_LEVEL'] = 'LOC'
    colossalai.launch_from_torch()
    args.world_size = int(os.environ.get('WORLD_SIZE', 1))
    store_path = '/mnt/dataset/huggingface'

    # Load tokenizer
    if args.model == 'GPT2':
        tokenizer = GPT2Tokenizer.from_pretrained('openai-community/gpt2', cache_dir=store_path)
        tokenizer.pad_token = tokenizer.eos_token
    else:
        raise NotImplementedError(f"Model {args.model} not implemented")

    # Load or prepare dataset
    train_tokenized, eval_tokenized, total_text_num = load_dataset(args)

    # Set up data collator
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Set up plugin
    if args.parallel == 'TP':
        plugin = HybridParallelPlugin(tp_size=args.tp_size, pp_size=1, precision='bf16', max_norm = 1.0, custom_policy=customPolicy.GPT2LMHeadModelPolicy())
    else:
        raise ValueError(f"Unknown parallel mode {args.parallel}")

    # Create data loaders
    train_loader, val_loader = create_dataloaders(args, train_tokenized, eval_tokenized, data_collator, plugin)

    # Create model
    model = create_model(args)

    # Initialize booster
    booster = Booster(plugin=plugin)

    # Initialize wandb
    if torch.cuda.current_device() == 0 and args.wandb:
        ex_name = f"OverParm_{args.model_name}_TP{plugin.tp_size}_batch:{args.batch}X{args.gradient_accumulation_value}_lr:{args.lr}_BF16_GA"
        wandb.init(project=args.project_name, entity="YOUR ENTITY", name=ex_name)

    if args.mode == 'train':
        # Create optimizer
        optimizer = create_optimizer(args, model)
        # Boost model and optimizer
        model, optimizer, _, train_loader, _ = booster.boost(
            model=model,
            optimizer=optimizer,
            criterion=None,
            dataloader=train_loader,
            lr_scheduler=None
        )

        model_checkpoint_path, start_batch = get_latest_checkpoint(args.output_dir, "checkpoint_model_batch_idx")
        optimizer_checkpoint_path, _ = get_latest_checkpoint(args.output_dir, "checkpoint_optimizer_batch_idx")
        # Load model and optimizer if checkpoint exists
        if model_checkpoint_path:
            booster.load_model(model, model_checkpoint_path)
            print(f"Model checkpoint loaded from {model_checkpoint_path}.")
            args.start_batch = start_batch  # Update start_batch from checkpoint

        if optimizer_checkpoint_path:
            booster.load_optimizer(optimizer, optimizer_checkpoint_path)
            print(f"Optimizer checkpoint loaded from {optimizer_checkpoint_path}.")

        model = model.to(args.device)

        # Print model size and parameter count
        if torch.cuda.current_device() == 0:
            size_model = sum(p.numel() * p.element_size() for p in model.parameters())
            print(model)
            print(f"Model size: {size_model / 1e6:.2f} MB")
            print('Model Parameters:', sum(p.numel() for p in model.parameters()))

        # Training loop
        for epoch in range(args.start_epoch, args.epoch):
            # Give a shuffling distributed sampler a fresh permutation each
            # epoch; samplers without set_epoch are left untouched.
            sampler = getattr(train_loader, "sampler", None)
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(epoch)

            start_time = time.time()
            train_one_epoch(model, train_loader, booster, optimizer, epoch, args.batch, total_text_num, args)
            # The resume offset only applies to the epoch that was interrupted;
            # later epochs must start from their first batch.
            args.start_batch = 0
            print(f'\nEpoch {epoch} completed in {time.time() - start_time:.2f} seconds')
            # Optional: Validation step
            val_one_epoch(model, val_loader, epoch, args)
            booster.save_model(model, os.path.join(args.output_dir, f"checkpoint_model_epoch{epoch}"), shard = True, gather_dtensor=False)
            booster.save_optimizer(optimizer, os.path.join(args.output_dir, f"checkpoint_optimizer_epoch{epoch}"), shard = True, gather_dtensor=False)
    else:
        # Evaluation mode
        pass



def train_one_epoch(model, loader, booster, optimizer, epoch, batch_size, total_text_num, args):
    """Run one training epoch, with optional gradient accumulation.

    Each batch is forwarded through `model`, and the loss (divided by the
    accumulation length) is back-propagated through ``booster.backward``. The
    optimizer steps once every accumulation window; without
    ``--gradient_accumulation`` the window length is 1, i.e. every batch is an
    optimizer step. Gradient synchronisation is suppressed with
    ``booster.no_sync`` on every micro-batch except the last one of a window,
    so the gradients the optimizer consumes are the synchronised ones.

    After each optimizer step the window-averaged loss is sum-reduced to rank
    0, divided by ``args.world_size``, printed, and optionally logged to wandb.

    Args:
        model: boosted model; called with ``input_ids``, ``attention_mask`` and
            ``labels`` and expected to return an object with a ``loss``.
        loader: sized iterable of batches, each a mapping with the keys
            ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.
        booster: object providing ``no_sync(model)`` and
            ``backward(loss, optimizer)``.
        optimizer: boosted optimizer.
        epoch: epoch number, used for logging only.
        batch_size: unused; kept for interface compatibility.
        total_text_num: unused; kept for interface compatibility.
        args: namespace; reads ``start_batch``, ``device``,
            ``gradient_accumulation``, ``gradient_accumulation_value``,
            ``world_size`` and ``wandb``.
    """
    model.train()
    batch_idx = 0
    total_loss = 0
    optimizer.zero_grad()
    t1 = time.time()

    # Window length of 1 means "step on every batch" (no accumulation).
    if args.gradient_accumulation:
        accumulation_value = max(1, args.gradient_accumulation_value)
    else:
        accumulation_value = 1

    for batch in loader:
        # When resuming, skip the batches already consumed before the checkpoint.
        if batch_idx < args.start_batch:
            batch_idx += 1
            continue
        batch_idx += 1
        input_ids = batch['input_ids'].to(args.device)
        attention_mask = batch['attention_mask'].to(args.device)
        labels = batch['labels'].to(args.device)

        accumulation_step = (batch_idx - 1) % accumulation_value
        is_update_step = accumulation_step == accumulation_value - 1
        # Only the last micro-batch of a window may synchronise gradients;
        # earlier ones just accumulate locally.
        sync_context = nullcontext() if is_update_step else booster.no_sync(model)
        with sync_context:
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            scaled_loss = loss / accumulation_value
            booster.backward(scaled_loss, optimizer)
            total_loss += loss.detach()  # Accumulate unscaled loss

        if not is_update_step:
            continue

        optimizer.step()
        optimizer.zero_grad()
        # Compute average loss over accumulation steps
        total_loss = total_loss / accumulation_value
        dist.reduce(total_loss, dst=0, op=dist.ReduceOp.SUM)
        if torch.cuda.current_device() == 0:
            total_loss = total_loss / args.world_size
            ppl = torch.exp(total_loss)
            progress = (batch_idx / len(loader)) * 100
            print(f'\r[Epoch {epoch}] Progress: {progress:.3f}%   Train Loss: {total_loss:.3f}   PPL: {ppl:.1f}   Time: {time.time() - t1:.3f}s', end='')
            if args.wandb:
                wandb.log({
                    "epoch": epoch,
                    "train loss": total_loss.item(),
                    "train PPL": ppl.item(),
                    "lr": optimizer.param_groups[0]['lr'],
                    "step": batch_idx - 1
                })
        # NOTE: a disabled, time-based mid-epoch checkpoint used to sit here.
        # After 11 h 40 min of wall time it broadcast a save flag from rank 1,
        # saved sharded checkpoints to
        # "checkpoint_model_batch_idx{batch_idx}" and
        # "checkpoint_optimizer_batch_idx{batch_idx}" under args.output_dir,
        # then hit a barrier and exited. Those names are what main() searches
        # for when resuming.
        total_loss = 0
        t1 = time.time()

    print('\r', end='')

@torch.no_grad()
def val_one_epoch(model, loader, epoch, args):
    """Evaluate `model` on `loader` and report the mean loss and perplexity.

    The per-batch losses are summed in float32 and divided by the number of
    batches; that mean is sum-reduced to rank 0 and divided by
    ``args.world_size``. The process whose current CUDA device is 0 prints the
    result and, when ``args.wandb`` is set, logs it to wandb.

    Args:
        model: model called with ``input_ids``, ``attention_mask`` and
            ``labels``; its output must expose ``loss``.
        loader: iterable of batches with the keys ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'``.
        epoch: epoch number, used for logging only.
        args: namespace; reads ``device``, ``world_size`` and ``wandb``.
    """
    model.eval()
    began_at = time.time()
    summed_loss = 0
    num_batches = 0
    for num_batches, batch in enumerate(loader, start=1):
        on_device = {
            key: batch[key].to(args.device)
            for key in ('input_ids', 'attention_mask', 'labels')
        }
        result = model(
            input_ids=on_device['input_ids'],
            attention_mask=on_device['attention_mask'],
            labels=on_device['labels'],
        )
        summed_loss += result.loss.to(torch.float32)

    total_loss = summed_loss / num_batches
    dist.reduce(total_loss, dst=0, op=dist.ReduceOp.SUM)

    # Only the process on CUDA device 0 reports.
    if torch.cuda.current_device() != 0:
        return

    total_loss /= args.world_size
    ppl = torch.exp(total_loss)
    if args.wandb:
        wandb.log({
            "epoch": epoch,
            "val loss": total_loss.item(),
            "val PPL": ppl.item()
        })
    print(f'\nValidation Loss: {total_loss:.3f}   PPL: {ppl:.1f}   Time: {time.time() - began_at:.3f}s', end='')

# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
